import os
import time
import re
import uuid
import copy

from random import randint

from collections import Counter

from inspect import signature

import boto3

from fabric import Connection
from botocore.config import Config
from botocore.exceptions import ClientError
from invoke import run
from packaging.version import Version
from packaging.specifiers import SpecifierSet
from tenacity import (
    retry,
    stop_after_attempt,
    stop_after_delay,
    wait_fixed,
    wait_random_exponential,
)

from test.test_utils import (
    get_synapseai_version_from_tag,
    is_deep_canary_context,
    is_pr_context,
    is_mainline_context,
    are_heavy_instance_ec2_tests_enabled,
    login_to_ecr_registry,
    get_account_id_from_image_uri,
    UL_AMI_LIST,
    DEFAULT_REGION,
    P4DE_REGION,
    BENCHMARK_RESULTS_S3_BUCKET,
)
from infra.test_infra.test_infra_utils import create_logger

EC2_INSTANCE_ROLE_NAME = "ec2TestInstanceRole"

# List of instance types for which, if instance spin-up fails, the test is skipped instead of failing.
ICE_SKIP_INSTANCE_LIST = []

# List of instance types that are too powerful (and expensive) for routine tests;
# they are only used when heavy-instance EC2 tests are explicitly enabled
HEAVY_INSTANCE_LIST = ["p4d.24xlarge", "p4de.24xlarge", "p5.48xlarge"]

# Flag to enable IPv6 testing
ENABLE_IPV6_TESTING = os.getenv("ENABLE_IPV6_TESTING", "false").lower() == "true"

IPV6_VPC_NAME = os.getenv("IPV6_VPC_NAME")

LOGGER = create_logger(__name__)


def filter_only_multi_gpu(instance_type_list):
    filtered_list = [
        instance_type
        for instance_type in instance_type_list
        if get_instance_num_gpus(instance_type=instance_type) > 1
    ]
    return filtered_list


def filter_only_multi_gpu_and_no_g_type(instance_type_list):
    filtered_list = [
        instance_type
        for instance_type in instance_type_list
        if get_instance_num_gpus(instance_type=instance_type) > 1
        and not instance_type.startswith("g")
    ]
    return filtered_list


def filter_only_single_gpu(instance_type_list):
    filtered_list = [
        instance_type
        for instance_type in instance_type_list
        if get_instance_num_gpus(instance_type=instance_type) == 1
    ]
    return filtered_list


def filter_no_t32x(instance_type_list):
    filtered_list = [
        instance_type for instance_type in instance_type_list if instance_type != "t3.2xlarge"
    ]
    return filtered_list


def is_instance_single_gpu(instance_type):
    return get_instance_num_gpus(instance_type=instance_type) == 1


def is_instance_multi_gpu(instance_type):
    return get_instance_num_gpus(instance_type=instance_type) > 1


def filter_not_heavy_instance_types(instance_type_list):
    filtered_list = [
        instance_type
        for instance_type in instance_type_list
        if instance_type not in HEAVY_INSTANCE_LIST
    ]
    return filtered_list


# https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa.html
# Neither g4dn nor the g5.24xlarge instances we use in RC support RDMA read,
# so performance tests fail on g5.24xlarge; exclude both families here.
def filter_efa_instance_type(instance_type_list):
    filtered_list = [
        instance_type
        for instance_type in instance_type_list
        if get_num_efa_interfaces_for_instance_type(instance_type)
        and not instance_type.startswith("g4")
        and not instance_type.startswith("g5")
    ]
    return filtered_list


def filter_efa_only_p4_instance_type(instance_type_list):
    filtered_list = [
        instance_type
        for instance_type in instance_type_list
        if get_num_efa_interfaces_for_instance_type(instance_type)
        and instance_type.startswith("p4")
    ]
    return filtered_list


def get_cicd_instance_reserved_region(instance_type):
    return P4DE_REGION if instance_type in ["p4de.24xlarge"] else DEFAULT_REGION


def get_efa_ec2_instance_type(default, filter_function=lambda x: x, job_type=""):
    """
    Helper function wrapping around get_ec2_instance_type to parametrize both ec2_instance_type
    as well as region in cases where certain instance types are reserved in a particular region.
    :param default: Default instance type to use
    :param filter_function: filter_function(instance_type_list) A function that takes the list to be generated by
    the logic of the get_ec2_instance_type function, and filters the list to only produce "acceptable" instances.
    For example, this can be a function that only returns multi-gpu instance types from a given list of instance types.
    :param job_type: str "training"/"inference"/"" as required by the instance-type being tested
    :return: one item list of instance type -- this is used to parametrize tests, and parameter is required to be
    a list.
    """
    instance_list = get_ec2_instance_type(default, "gpu", filter_function, job_type=job_type)
    instance_region_list = [
        (instance_type, get_cicd_instance_reserved_region(instance_type))
        for instance_type in instance_list
    ]
    return instance_region_list
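
# Usage sketch (hedged): the (instance_type, region) tuples returned above are typically
# fed to pytest parametrization. The default below is illustrative.
#
#   EC2_EFA_GPU_INSTANCE_TYPE_AND_REGION = get_efa_ec2_instance_type(
#       default="p4d.24xlarge",
#       filter_function=filter_efa_instance_type,
#   )
#   # e.g. [("p4d.24xlarge", DEFAULT_REGION)] in PR context when heavy-instance tests are enabled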


def get_ec2_instance_type(
    default, processor, filter_function=lambda x: x, arch_type="", job_type=""
):
    """
    Get EC2 instance type from associated EC2_[CPU|GPU]_INSTANCE_TYPE env variable, or set it to a
    default for contexts where the variable is not present (i.e. PR, Nightly, local testing)

    :param default: Default instance type to use
    :param processor: "cpu" or "gpu"
    :param filter_function: filter_function(instance_type_list) A function that takes the list to be
    generated by the logic of the get_ec2_instance_type function, and filters the list to only
    produce "acceptable" instances. For example, this can be a function that only returns multi-gpu
    instance types from a given list of instance types.

    :return: one item list of instance type -- this is used to parametrize tests, and parameter is
    required to be a list.
    """
    if is_pr_context() or is_deep_canary_context():
        # This condition filters out instance types that use resources with low-availability, or
        # use very expensive instance types.
        if not are_heavy_instance_ec2_tests_enabled() and default in HEAVY_INSTANCE_LIST:
            return []
        return [default]

    allowed_processors = ("cpu", "gpu", "neuronx", "neuron", "hpu")
    job_type_str = f"_{job_type.upper()}" if job_type else ""
    if processor not in allowed_processors:
        raise RuntimeError(
            f"Aborting EC2 test run. Unrecognized processor type {processor}. "
            f"Please choose from {allowed_processors}"
        )
    instance_type = os.getenv(f"EC2_{processor.upper()}{job_type_str}_INSTANCE_TYPE")
    if arch_type == "graviton" or arch_type == "arm64":
        instance_type = os.getenv(
            f"EC2_{processor.upper()}_{arch_type.upper()}{job_type_str}_INSTANCE_TYPE"
        )
    if not instance_type:
        return []

    instance_list = filter_function([instance_type] if instance_type else [])
    return instance_list
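
# Usage sketch (hedged; instance types and env var values are illustrative):
#
#   # In PR context, returns the default directly (or [] for un-enabled heavy instances):
#   get_ec2_instance_type(default="g5.8xlarge", processor="gpu")  # -> ["g5.8xlarge"]
#
#   # Elsewhere, reads e.g. EC2_GPU_TRAINING_INSTANCE_TYPE and applies the filter:
#   get_ec2_instance_type(
#       default="g5.8xlarge",
#       processor="gpu",
#       filter_function=filter_only_multi_gpu,
#       job_type="training",
#   )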


def get_ec2_accelerator_type(default, processor):
    """
    Get EC2 accelerator type from the associated EC2_EIA_INSTANCE_TYPE env variable, or set it to a
    default for contexts where the variable is not present (i.e. PR, Nightly, local testing)

    :param default: Default accelerator instance type to use
    :param processor: "eia"

    :return: one item list of instance type -- this is used to parametrize tests, and parameter is required to be
    a list.
    """
    allowed_processors = ("eia",)
    if processor not in allowed_processors:
        raise RuntimeError(
            f"Aborting EC2 test run. Unrecognized processor type {processor}. "
            f"Please choose from {allowed_processors}"
        )
    accelerator_type = os.getenv(f"EC2_{processor.upper()}_INSTANCE_TYPE")
    if not accelerator_type:
        if is_mainline_context():
            return []
        return [default]
    return [accelerator_type]


def launch_instance(
    ami_id,
    instance_type,
    ec2_key_name=None,
    region=DEFAULT_REGION,
    user_data=None,
    iam_instance_profile_name=None,
    instance_name="",
):
    """
    Launch an instance
    :param ami_id: AMI ID to be used for launched instance
    :param instance_type: Instance type of launched instance
    :param region: Region where instance will be launched
    :param user_data: Script to run when instance is launched as a str
    :param iam_instance_profile_arn: EC2 Role to be attached
    :param instance_name: Tag to display as Name on EC2 Console
    :return: <dict> Information about the instance that was launched
    """
    if not ami_id:
        raise Exception("No ami_id provided")
    if not ec2_key_name:
        raise Exception("Ec2 Key name must be provided")
    client = boto3.Session(region_name=region).client("ec2")
    LOGGER.info(f"Using AMI ID: {ami_id}")
    volume_name = "/dev/sda1" if ami_id in UL_AMI_LIST else "/dev/xvda"

    # Construct the dictionary with the arguments for API call
    arguments_dict = {
        "KeyName": ec2_key_name,
        "ImageId": ami_id,
        "InstanceType": instance_type,
        "MaxCount": 1,
        "MinCount": 1,
        "TagSpecifications": [
            {
                "ResourceType": "instance",
                "Tags": [{"Key": "Name", "Value": f"CI-CD {instance_name}"}],
            },
        ],
        "MetadataOptions": {
            "HttpTokens": "required",
            "HttpEndpoint": "enabled",
            "HttpPutResponseHopLimit": 2,
        },
        "BlockDeviceMappings": [
            {
                "DeviceName": volume_name,
                "Ebs": {
                    "VolumeSize": 150,
                },
            }
        ],
    }
    if user_data:
        arguments_dict["UserData"] = user_data
    if iam_instance_profile_name:
        arguments_dict["IamInstanceProfile"] = {"Name": iam_instance_profile_name}

    reservations = get_available_reservations(
        ec2_client=client, instance_type=instance_type, min_availability=arguments_dict["MinCount"]
    )

    while reservations:
        reservation = reservations.pop(0)
        arguments_dict["CapacityReservationSpecification"] = {
            "CapacityReservationTarget": {
                "CapacityReservationId": reservation["CapacityReservationId"]
            }
        }
        try:
            response = client.run_instances(**arguments_dict)
            LOGGER.info(
                f"Your {instance_type} reservation is ready, please wait to be seated. Launching..."
            )
            if is_mainline_context():
                LOGGER.info(f"Launched instance via {reservation}")
            return response["Instances"][0]
        except ClientError as e:
            LOGGER.error(f"Failed to launch via {instance_type} reservation - {e}")
            # Refresh available reservations
            time.sleep(randint(10, 30))
            reservations = get_available_reservations(
                ec2_client=client,
                instance_type=instance_type,
                min_availability=arguments_dict["MinCount"],
            )

    # No usable capacity reservation was found; remove the reservation spec and fall back
    arguments_dict.pop("CapacityReservationSpecification", None)
    LOGGER.info(f"No capacity reservation available for {instance_type}, trying elsewhere...")
    response = client.run_instances(**arguments_dict)

    if not response or len(response["Instances"]) < 1:
        raise Exception("Unable to launch the instance. Did not return any response.")

    return response["Instances"][0]


def get_available_reservations(ec2_client, instance_type, min_availability=1):
    """
    Get capacity reservations in our region that have our minimum availability

    Args:
        ec2_client (boto3.client): EC2 Boto3 client
        instance_type (string): instance type, i.e. g5.8xlarge
        min_availability (int, optional): Minimum number of instances to launch. Defaults to 1.

    Returns:
        list: list of dictionaries of reservations
    """
    reservations = ec2_client.describe_capacity_reservations()

    open_tables = [
        reservation
        for reservation in reservations["CapacityReservations"]
        if reservation["InstanceType"] == instance_type
        and reservation["AvailableInstanceCount"] >= min_availability
    ]

    # Stable double-sort: order primarily by ascending available instance count
    # (with total instance count as the tiebreaker), so that we take the minimum
    # instances required and leave other reservations open for larger parties
    open_tables.sort(key=lambda res: res["TotalInstanceCount"])

    return sorted(open_tables, key=lambda res: res["AvailableInstanceCount"])
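
# Shape sketch (hedged): each returned reservation is an abbreviated
# DescribeCapacityReservations entry; the values below are illustrative.
#
#   [
#       {
#           "CapacityReservationId": "cr-0123456789abcdef0",
#           "InstanceType": "p4d.24xlarge",
#           "AvailabilityZone": "us-west-2a",
#           "AvailableInstanceCount": 1,
#           "TotalInstanceCount": 2,
#           ...
#       },
#   ]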


@retry(
    reraise=True,
    stop=stop_after_delay(30 * 60),  # Keep retrying for 30 minutes
    wait=wait_random_exponential(min=60, max=5 * 60),  # Retry after waiting 1-5 minutes
)
def launch_instances_with_retry(
    ec2_resource, ec2_client, availability_zone_options, ec2_create_instances_definition, fn_name=""
):
    """
    Helper function to launch EC2 instances with retry capability, to allow multiple attempts
    when facing instance capacity issues.
    :param ec2_resource: boto3 EC2 Service Resource object
    :param ec2_client: boto3 EC2 Client object
    :param availability_zone_options: list of availability zones in which to try to run instances
    :param ec2_create_instances_definition: dict of parameters to pass to
        ec2_resource.create_instances
    :param fn_name: string - function name for ease of logging
    :return: list of EC2 Instance Resource objects for instances launched
    """

    instances = None
    reservations = get_available_reservations(
        ec2_client=ec2_client,
        instance_type=ec2_create_instances_definition["InstanceType"],
        min_availability=ec2_create_instances_definition["MinCount"],
    )
    # Look at available CRs first
    while reservations:
        reservation = reservations.pop(0)
        ec2_create_instances_definition["CapacityReservationSpecification"] = {
            "CapacityReservationTarget": {
                "CapacityReservationId": reservation["CapacityReservationId"]
            }
        }
        try:
            instances = ec2_resource.create_instances(**ec2_create_instances_definition)
            LOGGER.info(
                f"Your reservation is ready for {fn_name}, please wait to be seated. Launching..."
            )
            if is_mainline_context():
                LOGGER.info(f"Launched instance for {fn_name} via {reservation}")
            return instances
        except ClientError as e:
            LOGGER.error(f"Failed to launch via reservation for {fn_name} - {e}")

    # Remove the capacity reservation spec since launching from reservations failed
    ec2_create_instances_definition.pop("CapacityReservationSpecification", None)

    LOGGER.info(
        f"Looks like you didn't have a reservation for {fn_name}, let's see if we can seat you as a walk-in..."
    )

    if availability_zone_options:
        error = None
        for a_zone in availability_zone_options:
            ec2_create_instances_definition["Placement"] = {"AvailabilityZone": a_zone}
            try:
                instances = ec2_resource.create_instances(**ec2_create_instances_definition)
                if instances:
                    break
            except ClientError as e:
                LOGGER.error(f"Failed to launch in {a_zone} due to {e} for {fn_name}")
                error = e
                continue
        if not instances:
            raise error
    else:
        instances = ec2_resource.create_instances(**ec2_create_instances_definition)
    return instances
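
# Usage sketch (hedged; the definition mirrors ec2_resource.create_instances kwargs, and
# the AMI ID/key name are illustrative placeholders):
#
#   ec2_resource = boto3.resource("ec2", region_name=DEFAULT_REGION)
#   ec2_client = get_ec2_client(DEFAULT_REGION)
#   instances = launch_instances_with_retry(
#       ec2_resource,
#       ec2_client,
#       availability_zone_options=["us-west-2a", "us-west-2b"],
#       ec2_create_instances_definition={
#           "ImageId": "ami-0123456789abcdef0",
#           "InstanceType": "g5.8xlarge",
#           "KeyName": "my-test-key",
#           "MinCount": 1,
#           "MaxCount": 1,
#       },
#       fn_name="test_example",
#   )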


def launch_efa(ec2_client, ec2_instance_type, ec2_run_instances_definition, availability_zone):
    ec2_efa_run_instances_definition = copy.deepcopy(ec2_run_instances_definition)
    ec2_efa_run_instances_definition.update(
        {
            "Placement": {"AvailabilityZone": availability_zone},
            "NetworkInterfaces": generate_network_interfaces(
                ec2_client, ec2_instance_type, availability_zone
            ),
        }
    )
    response = ec2_client.run_instances(**ec2_efa_run_instances_definition) or {}
    return response.get("Instances")


def launch_efa_with_reservations(
    ec2_client, ec2_instance_type, reservations, ec2_run_instances_definition, fn_name=""
):
    ec2_run_instances_reserved_definition = copy.deepcopy(ec2_run_instances_definition)
    while reservations:
        reservation = reservations.pop(0)
        ec2_run_instances_reserved_definition["CapacityReservationSpecification"] = {
            "CapacityReservationTarget": {
                "CapacityReservationId": reservation["CapacityReservationId"]
            }
        }
        try:
            instances = launch_efa(
                ec2_client,
                ec2_instance_type,
                ec2_run_instances_reserved_definition,
                reservation["AvailabilityZone"],
            )
            if instances:
                LOGGER.info(
                    f"Your EFA reservation is ready for {fn_name}, please wait to be seated. Launching..."
                )
                if is_mainline_context():
                    LOGGER.info(f"Launched EFA enabled instance for {fn_name} via {reservation}")
                return instances
        except ClientError as e:
            LOGGER.debug(
                f"Failed to launch EFA instance for {fn_name} from reservation due to {e}\n"
                "Checking additional open reservations..."
            )
    return []


def validate_efa_instance_conditions(instances, minimum_number_of_instances):
    if len(instances) == minimum_number_of_instances:
        return True
    if len(instances) > minimum_number_of_instances:
        raise RuntimeError(
            f"Launched too many instances somehow, raising and cleaning up - {instances}; min/max_allowed = {minimum_number_of_instances}"
        )
    return False


class HeterogenousReservationError(Exception):
    pass


def referesh_capacity_reservations(ec2_client, ec2_instance_type, az):
    reservations = [
        reservation
        for reservation in get_available_reservations(ec2_client, ec2_instance_type)
        if reservation["AvailabilityZone"] == az
    ]

    available_instances = sum(
        [reservation["AvailableInstanceCount"] for reservation in reservations]
    )

    return reservations, available_instances


def launch_efa_with_heterogenous_reservations(ec2_client, ec2_run_instances_definition, fn_name=""):
    """
    Launch efa instances with heterogenous reservations

    Previous EFA launch code requires instances to be launched from the same command. This prohibits launching instances
    from multiple capacity reservations if the reservation has less than the minimum available instances required (typically 2).

    To remedy this, we group reservations by availability zone. If we have instances available in reservation, we
    group by most common availability zone and try to launch multiple instances from reservation. If we do not meet our minimum
    requirements, try launching from public pool to remedy the situation. If we launch 0 from reservation, do not
    try launching from the public pool, and allow other functions to handle launching exclusively from public.

    Args:
        ec2_client (boto3.client): boto3 ec2 client
        ec2_run_instances_definition (dict): key/value pairs for run instances launch cmd
        fn_name (str, optional): pytest function name. Defaults to "".

    Raises:
        HeterogenousReservationError: Custom error handling for function failure

    Returns:
        list: launched instances
    """
    ec2_heterogenous_run_instances_definition = copy.deepcopy(ec2_run_instances_definition)
    ec2_instance_type = ec2_heterogenous_run_instances_definition["InstanceType"]
    minimum_number_of_instances = ec2_heterogenous_run_instances_definition["MinCount"]

    # Reset max and min count to 1; we will launch instances one at a time, stringing
    # together capacity from multiple reservations
    ec2_heterogenous_run_instances_definition["MaxCount"] = 1
    ec2_heterogenous_run_instances_definition["MinCount"] = 1

    reserved_azs = [
        reservation["AvailabilityZone"]
        for reservation in ec2_client.describe_capacity_reservations()["CapacityReservations"]
        if reservation["InstanceType"] == ec2_instance_type
    ]

    tmp_reservations = get_available_reservations(
        ec2_client=ec2_client,
        instance_type=ec2_instance_type,
        min_availability=ec2_heterogenous_run_instances_definition["MinCount"],
    )

    az_counter = Counter(reservation["AvailabilityZone"] for reservation in tmp_reservations)
    az_priorities = [c[0] for c in az_counter.most_common()]

    # Track all reserved availability zones, in case capacity comes later
    for reserved_az in reserved_azs:
        if reserved_az not in az_priorities:
            az_priorities.append(reserved_az)

    for az in az_priorities:
        LOGGER.info(f"Checking AZ {az}")
        # Refresh reservations for each AZ
        reservations, available_instances = referesh_capacity_reservations(
            ec2_client, ec2_instance_type, az
        )
        ec2_heterogenous_run_instances_definition["MaxCount"] = 1
        ec2_heterogenous_run_instances_definition["MinCount"] = 1
        instances = []
        try:
            while available_instances and len(instances) < minimum_number_of_instances:
                LOGGER.info(f"trying to launch {ec2_instance_type} in {az}")
                launched = launch_efa_with_reservations(
                    ec2_client=ec2_client,
                    ec2_instance_type=ec2_instance_type,
                    reservations=reservations,
                    ec2_run_instances_definition=ec2_heterogenous_run_instances_definition,
                    fn_name=fn_name,
                )
                instances += launched

                # Refresh reservations for each AZ
                reservations, available_instances = referesh_capacity_reservations(
                    ec2_client, ec2_instance_type, az
                )

            if validate_efa_instance_conditions(instances, minimum_number_of_instances):
                LOGGER.info("Strung together some reservations, let's go")
                return instances

            # If we have remaining instances, try launching from public pool
            # Try a different availability zone if we don't have any reservation launches, however. Always
            # prioritize reservation launches in this function.
            remaining_instances = minimum_number_of_instances - len(instances)
            if remaining_instances != minimum_number_of_instances:
                LOGGER.info(
                    f"Still need {remaining_instances} more instances in {az}. Trying from public pool."
                )
                ec2_heterogenous_run_instances_definition["MaxCount"] = remaining_instances
                ec2_heterogenous_run_instances_definition["MinCount"] = remaining_instances
                instances += launch_efa(
                    ec2_client, ec2_instance_type, ec2_heterogenous_run_instances_definition, az
                )

                if validate_efa_instance_conditions(instances, minimum_number_of_instances):
                    LOGGER.info("Strung together some reservations and some walk-ins, let's go")
                    return instances

                # Clean up instances if this workflow did not succeed
                LOGGER.info(
                    f"Failed to launch enough instances from public and reservations for {fn_name}."
                )
                if instances:
                    LOGGER.info(
                        f"Cleaning up instances {[instance['InstanceId'] for instance in instances]}..."
                    )
                    ec2_client.terminate_instances(
                        InstanceIds=[instance_info["InstanceId"] for instance_info in instances]
                    )

        except ClientError as e:
            # Clean up any remaining instances
            LOGGER.info(
                f"Failed to launch EFA instance for {fn_name} from reservation due to {e}\n"
                "Checking additional open reservations and cleaning up stray resources"
            )
            if instances:
                LOGGER.info(
                    f"Cleaning up instances {[instance['InstanceId'] for instance in instances]}..."
                )
                ec2_client.terminate_instances(
                    InstanceIds=[instance_info["InstanceId"] for instance_info in instances]
                )

        except Exception as e:
            if instances:
                LOGGER.info(
                    f"Cleaning up instances {[instance['InstanceId'] for instance in instances]}..."
                )
                ec2_client.terminate_instances(
                    InstanceIds=[instance_info["InstanceId"] for instance_info in instances]
                )
            raise HeterogenousReservationError(
                "Failed to launch via heterogeneous reservation approach"
            ) from e
    return []


@retry(
    reraise=True,
    stop=stop_after_delay(30 * 60),  # Keep retrying for 30 minutes
    wait=wait_random_exponential(min=60, max=5 * 60),  # Retry after waiting 1-5 minutes
)
def launch_efa_instances_with_retry(
    ec2_client,
    ec2_instance_type,
    availability_zone_options,
    ec2_run_instances_definition,
    fn_name="",
):
    """
    Helper function to launch EFA-capable EC2 instances with retry capability, to allow
    multiple attempts when facing instance capacity issues.
    :param ec2_client: boto3 EC2 Client object
    :param ec2_instance_type: str EC2 Instance Type
    :param availability_zone_options: list of availability zones in which to try to run instances
    :param ec2_run_instances_definition: dict of parameters to pass to ec2_client.run_instances
    :param fn_name: string - function name for ease of logging
    :return: list of instance dicts from the ec2_client.run_instances response
    """
    region = ec2_client.meta.region_name
    LOGGER.info(f"Trying to launch {ec2_instance_type} for {fn_name} via capacity reservation...")

    heterogenous_reservation_launch = launch_efa_with_heterogenous_reservations(
        ec2_client=ec2_client,
        ec2_run_instances_definition=ec2_run_instances_definition,
        fn_name=fn_name,
    )

    if heterogenous_reservation_launch:
        return heterogenous_reservation_launch

    LOGGER.info(
        f"Looks like you didn't have an EFA reservation for {fn_name}, let's see if we can seat you as a walk-in..."
    )

    instances = []
    for availability_zone in availability_zone_options:
        try:
            instances = launch_efa(
                ec2_client, ec2_instance_type, ec2_run_instances_definition, availability_zone
            )
            if instances:
                break
        except ClientError as e:
            LOGGER.info(
                f"Failed to launch in {availability_zone} for {fn_name} due to {e}\n"
                "Retrying in the next availability zone."
            )
            continue
    if not instances:
        raise RuntimeError(
            f"Unable to launch {ec2_instance_type} instances in {region} for {fn_name}"
        )
    return instances
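
# Usage sketch (hedged): the run_instances definition should not pre-set Placement or
# NetworkInterfaces, since launch_efa() injects both per availability zone. The AMI ID
# below is an illustrative placeholder.
#
#   instances = launch_efa_instances_with_retry(
#       ec2_client=get_ec2_client(DEFAULT_REGION),
#       ec2_instance_type="p4d.24xlarge",
#       availability_zone_options=["us-west-2a", "us-west-2c"],
#       ec2_run_instances_definition={
#           "ImageId": "ami-0123456789abcdef0",
#           "InstanceType": "p4d.24xlarge",
#           "MinCount": 2,
#           "MaxCount": 2,
#       },
#       fn_name="test_efa_example",
#   )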


def get_ec2_client(region):
    return boto3.client("ec2", region_name=region, config=Config(retries={"max_attempts": 10}))


def get_instance_from_id(instance_id, region=DEFAULT_REGION):
    """
    Get instance information using instance ID
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <dict> Information about instance with matching instance ID
    """
    if not instance_id:
        raise Exception("No instance id provided")
    client = boto3.Session(region_name=region).client("ec2")
    instance = client.describe_instances(InstanceIds=[instance_id])
    if not instance:
        raise Exception(
            f"Unable to describe instance {instance_id}. Did not return any reservations object."
        )
    return instance["Reservations"][0]["Instances"][0]


@retry(stop=stop_after_attempt(16), wait=wait_fixed(60))
def get_private_ip(instance_id, region=DEFAULT_REGION):
    """
    Get Private IP of instance using instance ID
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <str> Private IP Address of instance with matching instance ID
    """
    instance = get_instance_from_id(instance_id, region)
    if not instance["PrivateIpAddress"]:
        raise Exception("Private IP address not yet available")
    return instance["PrivateIpAddress"]


@retry(stop=stop_after_attempt(16), wait=wait_fixed(60))
def get_public_ip(instance_id, region=DEFAULT_REGION):
    """
    Get Public IP of instance using instance ID
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <str> IP Address of instance with matching instance ID
    """
    instance = get_instance_from_id(instance_id, region)
    if not instance["PublicIpAddress"]:
        raise Exception("IP address not yet available")
    return instance["PublicIpAddress"]


@retry(stop=stop_after_attempt(16), wait=wait_fixed(60))
def get_public_ip_from_private_dns(private_dns, region=DEFAULT_REGION):
    """
    Get Public IP of instance using private DNS
    :param private_dns:
    :param region:
    :return: <str> IP Address of instance with matching private DNS
    """
    client = boto3.Session(region_name=region).client("ec2")
    response = client.describe_instances(
        Filters=[{"Name": "private-dns-name", "Values": [private_dns]}]
    )
    return response.get("Reservations")[0].get("Instances")[0].get("PublicIpAddress")


@retry(stop=stop_after_attempt(16), wait=wait_fixed(60))
def get_instance_user(instance_id, region=DEFAULT_REGION):
    """
    Get "ubuntu" or "ec2-user" based on AMI used to launch instance
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <str> user name
    """
    instance = get_instance_from_id(instance_id, region)
    user = "ubuntu" if instance["ImageId"] in UL_AMI_LIST else "ec2-user"
    return user


def get_instance_state(instance_id, region=DEFAULT_REGION):
    """
    Get state of instance using instance ID
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <str> State of instance with matching instance ID
    """
    instance = get_instance_from_id(instance_id, region)
    return instance["State"]["Name"]


@retry(stop=stop_after_attempt(16), wait=wait_fixed(60))
def check_instance_state(instance_id, state="running", region=DEFAULT_REGION):
    """
    Compares the instance state with the state argument.
    Retries 16 times with a 60 second gap between retries.
    :param instance_id: Instance ID to be queried
    :param state: Expected instance state
    :param region: Region where query will be performed
    :return: <str> State of instance with matching instance ID
    """
    instance_state = get_instance_state(instance_id, region)
    if state != instance_state:
        raise Exception(f"Instance {instance_id} not in {state} state")
    return instance_state


def get_system_state(instance_id, region=DEFAULT_REGION):
    """
    Returns health checks state for instances
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <tuple> System state and Instance state of instance with matching instance ID
    """
    if not instance_id:
        raise Exception("No instance id provided")
    client = boto3.Session(region_name=region).client("ec2")
    response = client.describe_instance_status(InstanceIds=[instance_id])
    if not response:
        raise Exception("Unable to describe instance status. Did not return any response.")
    instance_status_list = response["InstanceStatuses"]
    if not instance_status_list:
        raise Exception(
            f"The instance id {instance_id} seems to be incorrect. Instance statuses came back empty."
        )

    instance_status = instance_status_list[0]
    return (
        instance_status["SystemStatus"]["Status"],
        instance_status["InstanceStatus"]["Status"],
    )


@retry(stop=stop_after_attempt(96), wait=wait_fixed(10))
def check_system_state(
    instance_id, system_status="ok", instance_status="ok", region=DEFAULT_REGION
):
    """
    Compares the system state (Health Checks).
    Retries 96 times with 10 seconds gap between retries
    :param instance_id: Instance ID to be queried
    :param system_status: Expected system state
    :param instance_status: Expected instance state
    :param region: Region where query will be performed
    :return: <tuple> System state and Instance state of instance with matching instance ID
    """
    instance_state = get_system_state(instance_id, region=region)
    if system_status != instance_state[0] or instance_status != instance_state[1]:
        raise Exception(f"Instance {instance_id} not in required state")
    return instance_state
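
# Usage sketch (hedged): typical wait-until-ready flow after launching an instance.
#
#   check_instance_state(instance_id, state="running", region=DEFAULT_REGION)
#   system_status, instance_status = check_system_state(
#       instance_id, system_status="ok", instance_status="ok", region=DEFAULT_REGION
#   )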


def terminate_instance(instance_id, region=DEFAULT_REGION):
    """
    Terminate EC2 instances with matching instance ID
    :param instance_id: Instance ID to be terminated
    :param region: Region where instance is located
    """
    if not instance_id:
        raise Exception("No instance id provided")
    client = boto3.Session(region_name=region).client("ec2")
    response = client.terminate_instances(InstanceIds=[instance_id])
    if not response:
        raise Exception("Unable to terminate instance. No response received.")
    instances_terminated = response["TerminatingInstances"]
    if not instances_terminated:
        raise Exception("Failed to terminate instance.")
    if instances_terminated[0]["InstanceId"] != instance_id:
        raise Exception("Failed to terminate instance. Unknown error.")


def get_instance_type_details(instance_type, region=DEFAULT_REGION):
    """
    Get instance type details for a given instance type
    :param instance_type: Instance type to be queried
    :param region: Region where query will be performed
    :return: <dict> Information about instance type
    """
    client = boto3.client("ec2", region_name=region)
    response = client.describe_instance_types(InstanceTypes=[instance_type])
    if not response or not response["InstanceTypes"]:
        raise Exception("Unable to get instance details. No response received.")
    if response["InstanceTypes"][0]["InstanceType"] != instance_type:
        raise Exception(
            f"Bad response received. Requested {instance_type} "
            f"but got {response['InstanceTypes'][0]['InstanceType']}"
        )
    return response["InstanceTypes"][0]


def get_instance_details(instance_id, region=DEFAULT_REGION):
    """
    Get instance details for instance with given instance ID
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <dict> Information about instance with matching instance ID
    """
    if not instance_id:
        raise Exception("No instance id provided")
    instance = get_instance_from_id(instance_id, region=region)
    if not instance:
        raise Exception("Could not find instance")

    return get_instance_type_details(instance["InstanceType"], region=region)


@retry(stop=stop_after_attempt(30), wait=wait_fixed(10))
def get_instance_num_cpus(instance_id, region=DEFAULT_REGION):
    """
    Get number of VCPUs on instance with given instance ID
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <int> Number of VCPUs on instance with matching instance ID
    """
    instance_info = get_instance_details(instance_id, region=region)
    return instance_info["VCpuInfo"]["DefaultVCpus"]


@retry(stop=stop_after_attempt(30), wait=wait_fixed(10))
def get_instance_memory(instance_id, region=DEFAULT_REGION):
    """
    Get total RAM available on instance with given instance ID
    :param instance_id: Instance ID to be queried
    :param region: Region where query will be performed
    :return: <int> Total RAM available on instance with matching instance ID
    """
    instance_info = get_instance_details(instance_id, region=region)
    return instance_info["MemoryInfo"]["SizeInMiB"]


@retry(stop=stop_after_attempt(30), wait=wait_fixed(10))
def get_instance_num_inferentias(instance_id=None, instance_type=None, region=DEFAULT_REGION):
    """
    Get total number of Inferentia accelerators on an instance, given instance ID or instance type
    :param instance_id: Instance ID to be queried
    :param instance_type: Instance Type to be queried
    :param region: Region where query will be performed
    :return: <int> Number of Inferentia accelerators on the matching instance
    """
    assert instance_id or instance_type, "Input must be either instance_id or instance_type"
    instance_info = (
        get_instance_type_details(instance_type, region=region)
        if instance_type
        else get_instance_details(instance_id, region=region)
    )
    return sum(
        neuron_type["Count"]
        for neuron_type in instance_info["InferenceAcceleratorInfo"]["Accelerators"]
        if neuron_type["Name"] == "Inferentia"
    )


@retry(stop=stop_after_attempt(30), wait=wait_fixed(10))
def get_instance_num_gpus(instance_id=None, instance_type=None, region=DEFAULT_REGION):
    """
    Get total number of GPUs on an instance, given instance ID or instance type
    :param instance_id: Instance ID to be queried
    :param instance_type: Instance Type to be queried
    :param region: Region where query will be performed
    :return: <int> Number of GPUs on the matching instance
    """
    assert instance_id or instance_type, "Input must be either instance_id or instance_type"
    instance_info = (
        get_instance_type_details(instance_type, region=region)
        if instance_type
        else get_instance_details(instance_id, region=region)
    )
    return sum(gpu_type["Count"] for gpu_type in instance_info["GpuInfo"]["Gpus"])


@retry(stop=stop_after_attempt(30), wait=wait_fixed(10))
def get_num_efa_interfaces_for_instance_type(instance_type, region=DEFAULT_REGION):
    """
    Get the maximum number of EFA interfaces available on a particular instance type
    :param instance_type: str EC2 Instance type
    :param region: str Region where ec2 instance must be launched
    :return: NoneType/int Number of EFA interfaces that can be created on the given instance type.
    Can be None if instance_type doesn't support EFA.
    """
    instance_info = get_instance_type_details(instance_type, region)
    num_efa_interfaces = (
        instance_info.get("NetworkInfo", {}).get("EfaInfo", {}).get("MaximumEfaInterfaces")
    )
    return num_efa_interfaces
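
# Usage sketch (hedged; counts reflect typical DescribeInstanceTypes output but may vary):
#
#   get_num_efa_interfaces_for_instance_type("p4d.24xlarge")  # -> 4
#   get_num_efa_interfaces_for_instance_type("t3.2xlarge")    # -> None (no EFA support)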


def get_ec2_fabric_connection(instance_id, instance_pem_file, region):
    """
    establish connection with EC2 instance if necessary
    :param instance_id: ec2_instance id
    :param instance_pem_file: instance key name
    :param region: Region where ec2 instance is launched
    :return: Fabric connection object
    """
    user = get_instance_user(instance_id, region=region)
    conn = Connection(
        user=user,
        host=get_public_ip(instance_id, region),
        connect_kwargs={"key_filename": [instance_pem_file]},
        connect_timeout=18000,
    )
    return conn
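
# Usage sketch (hedged; the PEM file path is an illustrative placeholder):
#
#   conn = get_ec2_fabric_connection(instance_id, "~/keys/my-test-key.pem", DEFAULT_REGION)
#   result = conn.run("nvidia-smi", hide=True)
#   print(result.stdout)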


def get_ec2_instance_tags(instance_id, region=DEFAULT_REGION, ec2_client=None):
    ec2_client = ec2_client or get_ec2_client(region)
    response = ec2_client.describe_tags(Filters=[{"Name": "resource-id", "Values": [instance_id]}])
    return {tag["Key"]: tag["Value"] for tag in response.get("Tags")}


# If IMDSv2 is enforced on an EC2 instance with hop limit 1, IMDSv2 API calls from inside
# containers don't work; with hop limit > 1, they do.
@retry(stop=stop_after_attempt(16), wait=wait_fixed(60))
def enforce_IMDSv2(instance_id, hop_limit, region=DEFAULT_REGION, ec2_client=None):
    """
    Enable HTTP TOKENS required option on EC2 instance with given hop limit.

    :param instance_id: str, ec2 instance id
    :param region: str, Region where ec2 instance is launched.
    :param ec2_client: str, ec2 client.
    :param hop_limit: str, hop limit to be set on ec2 instance.
    """
    ec2_client = ec2_client or get_ec2_client(region)
    response = ec2_client.modify_instance_metadata_options(
        InstanceId=instance_id,
        HttpTokens="required",
        HttpPutResponseHopLimit=hop_limit,
        HttpEndpoint="enabled",
    )

    if not response:
        raise Exception("Unable to enforce IMDSv2. No response received.")

    time.sleep(2)
    state = None
    if response["InstanceId"]:
        res = ec2_client.describe_instances(InstanceIds=[instance_id])
        if res:
            metadata_options = res["Reservations"][0]["Instances"][0]["MetadataOptions"]
            state = metadata_options["State"]
            LOGGER.info(f"Modify Metadata options of EC2 instance: {metadata_options}")
    if state != "applied":
        raise Exception(
            "Unable to enforce IMDSv2. Describe instance is not able to confirm if IMDSv2 enforced."
        )
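
# Verification sketch (hedged): once IMDSv2 is enforced, token-less metadata requests are
# rejected while token-based requests succeed, e.g. when run on the instance over a
# fabric connection:
#
#   conn.run("curl -s http://169.254.169.254/latest/meta-data/", warn=True)  # fails without token
#   conn.run(
#       'TOKEN=$(curl -s -X PUT "http://169.254.169.254/latest/api/token" '
#       '-H "X-aws-ec2-metadata-token-ttl-seconds: 60") && '
#       'curl -s -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/'
#   )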


@retry(stop=stop_after_attempt(16), wait=wait_fixed(60))
def enforce_IMDSv1(instance_id, region=DEFAULT_REGION, ec2_client=None):
    """
    Enabled IMDSv1 on EC2 instance.

    :param instance_id: str, ec2 instance id
    :param region: str, Region where ec2 instance is launched.
    :param ec2_client: str, ec2 client.
    :param hop_limit: str, hop limit to be set on ec2 instance.
    """
    ec2_client = ec2_client or get_ec2_client(region)
    response = ec2_client.modify_instance_metadata_options(
        InstanceId=instance_id, HttpTokens="optional", HttpPutResponseHopLimit=1
    )

    if not response:
        raise Exception("Unable to enforce IMDSv1. No response received.")
    time.sleep(2)
    state = None
    if response["InstanceId"]:
        res = ec2_client.describe_instances(InstanceIds=[instance_id])
        if res:
            metadata_options = res["Reservations"][0]["Instances"][0]["MetadataOptions"]
            state = metadata_options["State"]
            LOGGER.info(f"Modify Metadata options of EC2 instance: {metadata_options}")
    if state != "applied":
        raise Exception(
            "Unable to enforce IMDSv1. Describe instance is not able to confirm if IMDSv1 enforced."
        )


def fetch_s3_file_and_get_last_line(s3_location, local_filename="temp.txt"):
    """
    Fetches the s3 file locally and extracts its last line.

    :param s3_location: str, s3 uri
    :param local_filename: str, location where s3 file is to be downloaded locally.
    :return: str, The last line of the file
    """
    run(f"rm -rf {local_filename}", hide=True)
    run(f"aws s3 cp {s3_location} {local_filename}", hide=True)
    last_line_of_file = run(f"tail -n1 {local_filename}", hide=True).stdout.strip()
    return last_line_of_file


def execute_asynchronus_testing_using_s3_bucket(
    connection,
    execution_command,
    connection_timeout,
    required_log_ending,
    loop_time=2.5 * 3600,
    log_location_within_ec2="~/container_tests/logs.txt",
    s3_uri_for_saving_permanent_logs=None,
    hang_detection_window=3,
):
    """
    This method uses fabric to run the provided execution_command in asynchronus mode. While the execution command
    is being executed in the image, it keeps on uploading the logs to the s3 bucket at fixed intervals. After a
    loop_time is over, it checks the last line of the uploaded logs to see if it is same as required_log_ending.
    This is mainly used in cases where Fabric behaves in an undesired way due to long living connections.

    :param connection: Fabric connection object
    :param execution_command: str, command that connection.run() will execute
    :param connection_timeout: timeout for fabric connection
    :param required_log_ending: str, The string that is desired to be present at the end of the logs
    :param loop_time: int, seconds for which we would wait for the tests to execute on ec2 instance
    :param log_location_within_ec2: Location within ec2 instance where the logs are being witten.
    :param s3_uri_for_saving_permanent_logs: Location where permanent s3 logs could be saved.
    :param hang_detection_window: int, This method detects a hang if length of log file does not change for hang_detection_window number of iterations.
    """
    account_id = os.getenv("ACCOUNT_ID", boto3.client("sts").get_caller_identity()["Account"])
    s3_bucket_name = f"dlc-async-test-{account_id}"
    if not s3_uri_for_saving_permanent_logs:
        unique_id = str(uuid.uuid4())
        unique_id_with_timestamp = f"{unique_id}-{int(time.time())}"
        s3_location = f"s3://{s3_bucket_name}/{unique_id_with_timestamp}.txt"
    else:
        s3_location = s3_uri_for_saving_permanent_logs
    connection.run(execution_command, hide=True, timeout=connection_timeout, asynchronous=True)
    start_time = int(time.time())
    loop_count = 0
    local_filename = s3_location.replace(":", "-").replace("/", "-")
    last_line_of_log = ""
    line_count_list = []
    while (int(time.time()) - start_time <= loop_time) and (
        not last_line_of_log.endswith(required_log_ending)
    ):
        time.sleep(5 * 60)
        loop_count += 1
        connection.run(
            f"aws s3 cp {log_location_within_ec2} {s3_location}", timeout=connection_timeout
        )
        last_line_of_log = fetch_s3_file_and_get_last_line(s3_location, local_filename)
        number_of_lines_in_log_file = int(
            run(f"wc -l {local_filename}", hide=True).stdout.strip().split()[0]
        )
        line_count_list.append(number_of_lines_in_log_file)
        number_of_previous_line_counts_to_check = hang_detection_window
        if len(line_count_list) >= number_of_previous_line_counts_to_check:
            if all(
                line_count == line_count_list[-1]
                for line_count in line_count_list[-number_of_previous_line_counts_to_check:]
            ):
                # If the last hang_detection_window iterations show the same line count, the job
                # has made no progress and is most likely hung, so we stop.
                LOGGER.info(
                    f"No progress reported for the past {number_of_previous_line_counts_to_check} iterations. Job most likely hung, so stopping the execution!!"
                )
                break
        LOGGER.info(f"Uploaded file to {s3_location} for {loop_count} number of times")

    if not last_line_of_log.endswith(required_log_ending):
        raise ValueError(
            f""" Test failed because the last row is not as expected. \n"""
            f""" Last row in the log file ===> {last_line_of_log} \n"""
            f""" expected ===> {required_log_ending}. \n"""
            f""" Full log ===> {s3_location} \n"""
        )


def get_s3_uri_for_saving_permanent_logs(
    framework, s3_bucket, test_type="ec2", custom_filename=None
):
    """
    Helper function to get s3 uri where log files generated within test ec2 instances will be uploaded to.

    :param framework: str, tensorflow, pytorch etc.
    :param s3_bucket: str, name of the bucket where we want to upload the logs.
    :param test_type: str, type of the test
    :param custom_filename: str, custom name of the file that will be prepended with unique id to create the s3 filepath
    """
    commit_id = os.getenv("CODEBUILD_RESOLVED_SOURCE_VERSION", f"default-{int(time.time())}")
    unique_id = str(uuid.uuid4())
    unique_id_with_timestamp = f"{unique_id}-{int(time.time())}"
    if custom_filename:
        filename = f"{custom_filename}-logs-{unique_id_with_timestamp}.txt"
    else:
        filename = f"logs-{unique_id_with_timestamp}.txt"
    s3_filepath = os.path.join(s3_bucket, test_type, framework, commit_id, filename)
    s3_permanent_log_upload_uri = f"s3://{s3_filepath}"
    return s3_permanent_log_upload_uri
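
# Shape sketch (hedged; bucket name, commit id, uuid and timestamp are illustrative):
#
#   get_s3_uri_for_saving_permanent_logs("pytorch", s3_bucket="my-log-bucket")
#   # -> "s3://my-log-bucket/ec2/pytorch/<commit-id>/logs-<uuid>-<timestamp>.txt"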


def execute_ec2_training_test(
    connection,
    ecr_uri,
    test_cmd,
    region=DEFAULT_REGION,
    executable="bash",
    large_shm=False,
    host_network=False,
    container_name="ec2_training_container",
    timeout=18000,
    bin_bash_entrypoint=False,
    enable_habana_async_execution=False,
    enable_gdrcopy=False,
):
    if executable not in ("bash", "python"):
        raise RuntimeError(
            "This function only supports executing bash or python commands on containers"
        )
    if executable == "bash":
        executable = os.path.join(os.sep, "bin", "bash")
    docker_runtime = "--runtime=nvidia --gpus all" if "gpu" in ecr_uri else ""
    container_test_local_dir = os.path.join("$HOME", "container_tests")
    synapseai_version = get_synapseai_version_from_tag(ecr_uri)
    # Make sure we are logged into ECR so we can pull the image
    account_id = get_account_id_from_image_uri(ecr_uri)
    login_to_ecr_registry(connection, account_id, region)

    # Run training command
    shm_setting = '--shm-size="1g"' if large_shm else ""
    network = '--network="host" ' if host_network else ""
    container_runtime = "--runtime=habana -e HABANA_VISIBLE_DEVICES=all" if "hpu" in ecr_uri else ""
    ompi_mca_btl = "-e OMPI_MCA_btl_vader_single_copy_mechanism=none" if "hpu" in ecr_uri else ""
    cap_add = "--cap-add=sys_nice" if "hpu" in ecr_uri else ""
    ipc = "--ipc=host" if "hpu" in ecr_uri and "pytorch" in ecr_uri else ""
    hpu_env_vars = f"-e GIT_BRANCH={synapseai_version}" if "hpu" in ecr_uri else ""
    habana_container_test_repo = (
        "-v ${HOME}/gaudi-test-suite:/gaudi-test-suite" if "hpu" in ecr_uri else ""
    )
    neuron_device = "--device=/dev/neuron0" if "neuron" in ecr_uri else ""
    gdr_device = "--device=/dev/gdrdrv" if enable_gdrcopy else ""
    bin_bash_cmd = "--entrypoint /bin/bash " if bin_bash_entrypoint else ""

    LOGGER.info(f"execute_ec2_training_test pulling {ecr_uri}, with cmd {test_cmd}")
    connection.run(f"docker pull {ecr_uri}", hide="out")
    connection.run(
        f"docker run {docker_runtime} --name {container_name} "
        f"{container_runtime} {ompi_mca_btl} {cap_add} {hpu_env_vars} "
        f"{ipc} {network}-v {container_test_local_dir}:{os.path.join(os.sep, 'test')} "
        f"{habana_container_test_repo} {shm_setting} {neuron_device} {gdr_device} -itd {bin_bash_cmd}{ecr_uri}",
        hide=True,
    )

    if "habana" in ecr_uri:
        execution_command = f"docker exec --user root {container_name} {executable} -c '{test_cmd}'"
        required_log_ending = "Kudos!! Habana tests executed successfully"
        framework = (
            "tensorflow" if "tensorflow" in ecr_uri else "pytorch" if "pytorch" in ecr_uri else None
        )
        test_type = "ec2"
        account_id_prefix = os.getenv(
            "ACCOUNT_ID", boto3.client("sts").get_caller_identity()["Account"]
        )[:3]
        s3_bucket_for_permanent_logs = f"dlinfra-habana-tests-{account_id_prefix}"
        s3_uri_permanent_logs = get_s3_uri_for_saving_permanent_logs(
            framework, s3_bucket=s3_bucket_for_permanent_logs, test_type=test_type
        )
        if enable_habana_async_execution:
            execute_asynchronus_testing_using_s3_bucket(
                connection,
                execution_command,
                timeout,
                required_log_ending,
                loop_time=4 * 3600,
                s3_uri_for_saving_permanent_logs=s3_uri_permanent_logs,
                hang_detection_window=15,
            )
            return
        else:
            run_output = connection.run(execution_command, hide=True, timeout=timeout)
            try:
                connection.run(f"aws s3 cp ~/container_tests/logs.txt {s3_uri_permanent_logs}")
                LOGGER.info(f"Uploaded logs at: {s3_uri_permanent_logs}")
            except Exception:
                LOGGER.info("Could not upload the logs")
            return run_output

    # Hack: the AMI does not yet use the latest Neuron driver, so reload the driver before
    # running. Without this we see errors like the following:
    # [  214.939271] Neuron Driver Started with Version:2.x.381.0-b70a76a18efb5e89ffed987461e9a1009d8b6f1e
    # [  214.939619] neuron-driver 0000:00:1e.0: BAR 4: can't reserve [mem 0x1000000000-0x17ffffffff 64bit pref]
    if "neuron" in ecr_uri:
        connection.run("sudo modprobe -r neuron && sudo modprobe -i neuron")

    LOGGER.info(f"execute_ec2_training_test running {ecr_uri}, with cmd {test_cmd}")
    ec2_res = connection.run(
        f"docker exec --user root {container_name} {executable} -c '{test_cmd}'",
        hide=True,
        timeout=timeout,
    )
    LOGGER.info(f"execute_ec2_training_test completed {ecr_uri}, with cmd {test_cmd}")
    return ec2_res
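
# Usage sketch (hedged; the image URI and test script are illustrative placeholders):
#
#   execute_ec2_training_test(
#       connection=conn,
#       ecr_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:example-tag",
#       test_cmd="cd /test && bash testPyTorchExample",
#       large_shm=True,
#   )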


def execute_ec2_telemetry_test(
    connection,
    ecr_uri,
    call_type,
    container_name,
    test_cmd,
    opt_in=False,
    region=DEFAULT_REGION,
    timeout=900,
):
    """
    Execute telemetry tests on EC2 instances using Docker containers.

    Args:
        connection: EC2 connection object
        ecr_uri (str): ECR image URI
        call_type (str): Type of test to run ('bashrc', 'entrypoint', 'framework')
        container_name (str): Base name for the container
        test_cmd (str): Test command to execute
        opt_in (bool): Whether to run in opt-in mode (default: False)
        region (str): AWS region
        timeout (int): Timeout in seconds (default: 900)

    Returns:
        Result object from the connection.run command

    Raises:
        RuntimeError: If invalid call_type is provided
    """
    # Validate call type
    VALID_CALL_TYPES = {"bashrc", "entrypoint", "framework"}
    if call_type not in VALID_CALL_TYPES:
        raise RuntimeError(f"Invalid call_type. Must be one of: {', '.join(VALID_CALL_TYPES)}")

    # Set up Docker runtime configuration
    docker_runtime = "--runtime=nvidia --gpus all" if "gpu" in ecr_uri else ""
    if "pytorch" in ecr_uri:
        framework_env = f"-e FRAMEWORK='torch'"
    elif "tensorflow" in ecr_uri:
        framework_env = f"-e FRAMEWORK='tensorflow'"
    else:
        framework_env = ""
    opt_out_env = "" if opt_in else "-e OPT_OUT_TRACKING='true'"

    # Set up container and mount configuration
    test_suffix = "opt_in" if opt_in else "opt_out"
    container_name = (
        f"{container_name}_{call_type}_{test_suffix}"
        if call_type in {"bashrc", "entrypoint"}
        else f"{container_name}_{call_type}"
    )

    container_test_local_dir = os.path.join("$HOME", "container_tests")
    mount_path = f"-v {container_test_local_dir}:{os.path.join(os.sep, 'test')}"

    # Prepare test command
    test_cmd = f"{test_cmd} {call_type} {test_suffix}"
    LOGGER.info(f"Executing test: {test_cmd}")

    # For the entrypoint test, avoid invoking bashrc telemetry
    nobashrc_cmd = "bash --norc" if call_type == "entrypoint" else ""

    # For all other tests, avoid invoking entrypoint telemetry
    entrypoint_override = "--entrypoint /bin/bash" if call_type != "entrypoint" else ""

    try:
        # Login to ECR and pull image
        account_id = get_account_id_from_image_uri(ecr_uri)
        login_to_ecr_registry(connection, account_id, region)

        LOGGER.info(f"Pulling image: {ecr_uri}")
        connection.run(f"docker pull {ecr_uri}", hide="out")

        # Execute test based on call type
        # Start container
        connection.run(
            f"docker run {docker_runtime} --name {container_name} "
            f" {mount_path} "
            f"-itd -e TEST_MODE='1' {framework_env} {opt_out_env} {entrypoint_override} {ecr_uri} {nobashrc_cmd}",
            hide=True,
        )

        # Execute test command
        ec2_res = connection.run(
            f"docker exec --user root {container_name} bash -c '{test_cmd}'",
            hide=True,
            timeout=timeout,
        )

        LOGGER.info(f"Test completed for {call_type} on {ecr_uri}")
        return ec2_res

    except Exception as e:
        LOGGER.error(f"Test failed: {str(e)}")
        raise


def execute_ec2_inference_test(connection, ecr_uri, test_cmd, region=DEFAULT_REGION):
    docker_runtime = "--runtime=nvidia --gpus all" if "gpu" in ecr_uri else ""
    container_test_local_dir = os.path.join("$HOME", "container_tests")

    # Make sure we are logged into ECR so we can pull the image
    account_id = get_account_id_from_image_uri(ecr_uri)
    login_to_ecr_registry(connection, account_id, region)

    # Start the inference test container
    connection.run(
        f"docker run {docker_runtime} --name ec2_inference_container -v {container_test_local_dir}:{os.path.join(os.sep, 'test')}"
        f" -itd {ecr_uri} bash",
        hide=True,
    )
    connection.run(
        f"docker exec --user root ec2_inference_container {os.path.join(os.sep, 'bin', 'bash')} -c '{test_cmd}'",
        hide=True,
        timeout=3000,
    )


def execute_ec2_training_performance_test(
    connection,
    ecr_uri,
    test_cmd,
    region=DEFAULT_REGION,
    post_process=None,
    data_source="",
    threshold=None,
):
    docker_runtime = "--runtime=nvidia --gpus all" if "gpu" in ecr_uri else ""
    container_test_local_dir = os.path.join("$HOME", "container_tests")

    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    log_name = (
        f"{data_source}_results_{os.getenv('CODEBUILD_RESOLVED_SOURCE_VERSION')}_{timestamp}.txt"
    )
    log_location = os.path.join(container_test_local_dir, "benchmark", "logs", log_name)

    # Make sure we are logged into ECR so we can pull the image
    account_id = get_account_id_from_image_uri(ecr_uri)
    login_to_ecr_registry(connection, account_id, region)

    connection.run(f"docker pull {ecr_uri}", hide=True)

    # Run training command, display benchmark results to console
    connection.run(
        f"docker run {docker_runtime} --user root "
        f"-e LOG_FILE={os.path.join(os.sep, 'test', 'benchmark', 'logs', log_name)} "
        f"-e PR_CONTEXT={1 if is_pr_context() else 0} "
        f"-v {container_test_local_dir}:{os.path.join(os.sep, 'test')} {ecr_uri} "
        f"{os.path.join(os.sep, 'bin', 'bash')} -c {test_cmd}"
    )
    ec2_performance_upload_result_to_s3_and_validate(
        connection,
        ecr_uri,
        log_location,
        data_source,
        threshold,
        post_process,
        log_name,
    )

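# Example (illustrative; the benchmark script path and threshold value are
# placeholders):
#
#   execute_ec2_training_performance_test(
#       ec2_connection,
#       ecr_uri,
#       "/test/bin/benchmark/run_training_benchmark.sh",
#       post_process=post_process_mxnet_ec2_performance,
#       data_source="synthetic",
#       threshold={"Throughput": 1000},
#   )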

def execute_ec2_habana_training_performance_test(
    connection,
    ecr_uri,
    test_cmd,
    region=DEFAULT_REGION,
    data_source="",
    cards_num=None,
    timeout=18000,
):
    container_test_local_dir = os.path.join("$HOME", "container_tests")

    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    log_name = (
        f"{data_source}_results_{os.getenv('CODEBUILD_RESOLVED_SOURCE_VERSION')}_{timestamp}.txt"
    )
    synapseai_version = get_synapseai_version_from_tag(ecr_uri)
    # Make sure we are logged into ECR so we can pull the image
    account_id = get_account_id_from_image_uri(ecr_uri)
    login_to_ecr_registry(connection, account_id, region)

    connection.run(f"docker pull -q {ecr_uri}")

    container_runtime = "--runtime=habana -e HABANA_VISIBLE_DEVICES=all"
    hpu_env_vars = f"-e CARDS_NUM={cards_num} -e GIT_BRANCH={synapseai_version}"
    ompi_mca_btl = "-e OMPI_MCA_btl_vader_single_copy_mechanism=none"
    cap_add = "--cap-add=sys_nice"
    ipc = "--ipc=host" if "pytorch" in ecr_uri else ""
    habana_container_test_repo = "${HOME}/gaudi-test-suite:/gaudi-test-suite"
    execution_command = (
        f"docker run --user root "
        f"-e LOG_FILE={os.path.join(os.sep, 'test', 'benchmark', 'logs', log_name)} "
        f"-e PR_CONTEXT={1 if is_pr_context() else 0} "
        f"{container_runtime} {ompi_mca_btl} {hpu_env_vars} {cap_add} {ipc} "
        f"-v {container_test_local_dir}:{os.path.join(os.sep, 'test')} -v {habana_container_test_repo} "
        f"{ecr_uri} {os.path.join(os.sep, 'bin', 'bash')} -c '{test_cmd}'"
    )

    framework = (
        "tensorflow" if "tensorflow" in ecr_uri else "pytorch" if "pytorch" in ecr_uri else None
    )
    account_id_prefix = os.getenv(
        "ACCOUNT_ID", boto3.client("sts").get_caller_identity()["Account"]
    )[:3]
    s3_bucket_for_permanent_logs = f"dlinfra-habana-tests-{account_id_prefix}"
    test_type = "benchmark"
    custom_filename = test_cmd.split(os.sep)[-1]
    custom_filename += f"-cards-{cards_num}" if cards_num else "-cards-0"
    s3_uri_permanent_logs = get_s3_uri_for_saving_permanent_logs(
        framework,
        s3_bucket=s3_bucket_for_permanent_logs,
        test_type=test_type,
        custom_filename=custom_filename,
    )
    required_log_ending = "Kudos!! Habana tests executed successfully"
    execute_asynchronus_testing_using_s3_bucket(
        connection,
        execution_command,
        timeout,
        required_log_ending,
        loop_time=4 * 3600,
        s3_uri_for_saving_permanent_logs=s3_uri_permanent_logs,
        hang_detection_window=15,
    )
    LOGGER.info(f"Uploaded logs at: {s3_uri_permanent_logs}")
    return


def execute_ec2_inference_performance_test(
    connection,
    ecr_uri,
    test_cmd,
    region=DEFAULT_REGION,
    post_process=None,
    data_source="",
    threshold=None,
):
    docker_runtime = "--runtime=nvidia --gpus all" if "gpu" in ecr_uri else ""
    container_test_local_dir = os.path.join("$HOME", "container_tests")
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    log_name = (
        f"{data_source}_results_{os.getenv('CODEBUILD_RESOLVED_SOURCE_VERSION')}_{timestamp}.txt"
    )
    # Make sure we are logged into ECR so we can pull the image
    account_id = get_account_id_from_image_uri(ecr_uri)
    login_to_ecr_registry(connection, account_id, region)
    connection.run(f"docker pull -q {ecr_uri}")

    # Start the benchmark container, display benchmark results to console
    repo_name, image_tag = ecr_uri.split("/")[-1].split(":")
    container_name = f"{repo_name}-performance-{image_tag}-ec2"
    connection.run(
        f"docker run {docker_runtime} -d --name {container_name} "
        f"-e LOG_FILE={os.path.join(os.sep, 'test', 'benchmark', 'logs', log_name)} "
        f"-v {container_test_local_dir}:{os.path.join(os.sep, 'test')} {ecr_uri}"
    )
    try:
        connection.run(
            f"docker exec --user root {container_name} "
            f"{os.path.join(os.sep, 'bin', 'bash')} -c {test_cmd}"
        )
    except Exception as e:
        raise Exception(f"Failed to exec benchmark command: {e}") from e
    finally:
        connection.run(f"docker rm -f {container_name}")
    log_location = os.path.join(container_test_local_dir, "benchmark", "logs", log_name)
    ec2_performance_upload_result_to_s3_and_validate(
        connection,
        ecr_uri,
        log_location,
        data_source,
        threshold,
        post_process,
        log_name,
    )


def ec2_performance_upload_result_to_s3_and_validate(
    connection,
    ecr_uri,
    log_location,
    data_source,
    threshold,
    post_process,
    log_name,
    instance_type=None,
):
    framework = (
        "tensorflow" if "tensorflow" in ecr_uri else "mxnet" if "mxnet" in ecr_uri else "pytorch"
    )
    framework_version = re.search(r"\d+(\.\d+){2}", ecr_uri).group()
    py_version = "py2" if "py2" in ecr_uri else "py37" if "py37" in ecr_uri else "py3"
    processor = "gpu" if "gpu" in ecr_uri else "cpu"
    work_type = "training" if "training" in ecr_uri else "inference"
    s3_location = os.path.join(
        BENCHMARK_RESULTS_S3_BUCKET,
        framework,
        framework_version,
        "ec2",
        work_type,
        processor,
        py_version,
        log_name,
    )
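    # Pass `threshold` to the post_process callable only when its signature
    # declares that parameter, so both styles of post-processor defined below
    # remain supported.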
    params = {"connection": connection, "log_location": log_location}
    if "threshold" in signature(post_process).parameters:
        params["threshold"] = threshold
    performance_number = post_process(**params)
    if work_type == "inference" and framework == "tensorflow":
        unit = "s"
    elif work_type == "inference" and framework == "pytorch":
        unit = "ms"
    elif work_type == "training" and framework == "pytorch" and data_source == "imagenet":
        unit = "s/epoch"
    else:
        unit = "images/sec"
    description = "p99 latency " if unit in ("s", "ms") else ""
    for k, v in performance_number.items():
        performance_statement = (
            f"{framework} {framework_version} ec2 {work_type} {processor} {py_version} "
            f"{instance_type if instance_type else ''} {data_source} {k} {description}: {v} {unit}, threshold: {threshold[k]} {unit}"
        )
        connection.run(f"echo {performance_statement} | sudo tee -a {log_location}")
        LOGGER.info(f"{performance_statement}")
    connection.run(f"aws s3 cp {log_location} {s3_location}")
    LOGGER.info(f"To retrieve complete benchmark log, check {s3_location}")

    def _assertion_results():
        if "Cost" in performance_number:
            return performance_number["Cost"] < threshold["Cost"]
        if "Throughput" in performance_number:
            return performance_number["Throughput"] > threshold["Throughput"]
        if len(performance_number) == 0:
            return False
        failure_count = 0
        for k, v in performance_number.items():
            if v > threshold[k]:
                failure_count += 1
        return failure_count <= 2

    # _assertion_results() does not depend on any particular metric key, so
    # assert it once; looping over the keys also silently passed when no metrics
    # were parsed, even though the helper returns False for an empty dict.
    assert _assertion_results(), (
        f"{framework} {framework_version} ec2 {work_type} {processor} {py_version} {data_source} "
        f"Benchmark Result {performance_number} does not reach the threshold {threshold}"
    )

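# A `post_process` callable receives `connection` and `log_location` (plus
# `threshold`, if its signature declares it) and must return a dict mapping
# metric names to float values, as the two parsers below do. A minimal sketch
# (the log format and "Throughput" key are illustrative assumptions):
#
#   def post_process_example(connection, log_location):
#       content = connection.run(f"cat {log_location}").stdout
#       match = re.search(r"throughput:\s*(?P<value>[0-9.]+)", content)
#       return {"Throughput": float(match.group("value"))}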

def post_process_inference(connection, log_location, threshold):
    log_content = connection.run(f"cat {log_location}").stdout.split("\n")
    performance_number = {}
    for line in log_content:
        if "p99" in line:
            for key in threshold.keys():
                if key in line:
                    performance_number[key] = float(
                        re.search(
                            r"(p99[ ]*(Latency)?[ ]*:[ ]*)(?P<result>[0-9]+\.?[0-9]+)",
                            line,
                        ).group("result")
                    )
                    break
    return performance_number


def post_process_mxnet_ec2_performance(connection, log_location):
    log_content = connection.run(f"cat {log_location}").stdout.split("\n")
    total = 0.0
    n = 0
    for line in log_content:
        if "samples/sec" in line and "warmup" not in line:
            throughput = re.search(r"((?P<throughput>[0-9]+\.?[0-9]+)[ ]+samples/sec)", line).group(
                "throughput"
            )
            total += float(throughput)
            n += 1
    if total and n:
        return {"Throughput": total / n}
    else:
        raise ValueError("total: {}; n: {} -- something went wrong".format(total, n))


def install_python_in_instance(context, python_version="3.9"):
    """
    Install Python on DLAMI EC2 instances to create a consistent test environment that is agnostic to the AMI used for the test.
    This helper function assumes that the EC2 instance uses a DLAMI. The /etc/profile.d/dlami.sh file doesn't exist
    in other AMIs. If support for other AMIs is needed, this function will need to be updated.
    :param context: Invoke Context / Fabric Connection object
    :param python_version: str python version to install, such as 3.8, 3.9, etc.
    :return: None
    """
    if context.run("pyenv --version", warn=True, hide=True).failed:
        context.run(
            """ls ~/.pyenv || git clone https://github.com/pyenv/pyenv.git ~/.pyenv""", hide=True
        )

        # for images that do not have /etc/profile.d/dlami.sh, we will make it here
        if context.run("test -f /etc/profile.d/dlami.sh", warn=True, hide=True).failed:
            LOGGER.info("/etc/profile.d/dlami.sh does not exist. Making...")
            context.run("sudo touch /etc/profile.d/dlami.sh")
            LOGGER.info("adding /etc/profile.d/dlami.sh to .bashrc")
            context.run(
                """echo '[ -z "$PS1" ] && source /etc/profile.d/dlami.sh'|cat - ~/.bashrc > ~/temprc """
                """&& mv ~/temprc ~/.bashrc""",
                hide=True,
            )

        context.run("sudo chmod 666 /etc/profile.d/dlami.sh", hide=True)
        context.run(
            """echo 'export PYENV_ROOT="$HOME/.pyenv"' >> /etc/profile.d/dlami.sh""", hide=True
        )
        context.run(
            """echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> /etc/profile.d/dlami.sh""",
            hide=True,
        )
        context.run("""echo 'eval "$(pyenv init -)"' >> /etc/profile.d/dlami.sh""", hide=True)
        context.run("sudo chmod 644 /etc/profile.d/dlami.sh", hide=True)
    context.run("sudo dnf update -y", hide=True)
    context.run(
        (
            "sudo dnf install -y make gcc gcc-c++ openssl-devel zlib-devel "
            "bzip2-devel readline-devel sqlite-devel llvm "
            "ncurses-devel xz tk-devel libxml2-devel xmlsec1-devel libffi-devel xz-devel --skip-broken"
        ),
        hide=True,
    )

    context.run(f"pyenv install {python_version}", hide=True)
    context.run(f"pyenv global {python_version}", hide=True)

    # Validate that installed python version is the same as requested python version
    python_version_response = context.run("python --version", hide=True)
    python_version_match = re.search(r"Python (\d+(\.\d+)+)", python_version_response.stdout)
    assert python_version_match, f"Could not parse a Python version from output: {python_version_response.stdout!r}"
    installed_python_version = python_version_match.group(1)
    # Use SpecifierSet("=={python_version}.*") to accommodate python_version of the form X.Y as well as X.Y.Z
    assert Version(installed_python_version) in SpecifierSet(
        f"=={python_version}.*"
    ), f"Installed python version {installed_python_version} does not match required python_version {python_version}"

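# Example (illustrative): pin the interpreter on a freshly provisioned DLAMI
# instance before running version-sensitive tests:
#
#   install_python_in_instance(ec2_connection, python_version="3.10")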

def get_availability_zone_ids(ec2_client):
    """
    Obtain list of AZs in a particular region using ec2_client
    :param ec2_client: boto3 EC2 Client object
    :return: list of str AZ names
    """
    response = ec2_client.describe_availability_zones()
    return [az["ZoneName"] for az in response["AvailabilityZones"]]


def get_default_vpc_id(ec2_client):
    """
    Get vpc-id of the default VPC in a particular region using ec2_client in that region
    :param ec2_client: boto3 EC2 Client object
    :return: str Default vpc-id
    """
    response = ec2_client.describe_vpcs(Filters=[{"Name": "is-default", "Values": ["true"]}])
    default_vpc_id = response["Vpcs"][0]["VpcId"]
    return default_vpc_id


def get_default_security_group_id(ec2_client):
    """
    Get security-group-id of default SG on the default VPC in a particular region using ec2_client
    :param ec2_client: boto3 EC2 Client object
    :return: str Default security-group-id
    """
    default_vpc_id = get_default_vpc_id(ec2_client)
    response = ec2_client.describe_security_groups(
        GroupNames=["default"],
        Filters=[{"Name": "vpc-id", "Values": [default_vpc_id]}],
    )
    default_security_group_id = response["SecurityGroups"][0]["GroupId"]
    return default_security_group_id


def get_efa_enabled_security_group_id(ec2_client):
    """
    Get security-group-id of custom EFA-enabled SG in the default VPC in a particular region
    :param ec2_client: boto3 EC2 Client object
    :return: str security-group-id of SG named "EFA-enabled"
    """
    default_vpc_id = get_default_vpc_id(ec2_client)
    response = ec2_client.describe_security_groups(
        GroupNames=["EFA-enabled"],
        Filters=[{"Name": "vpc-id", "Values": [default_vpc_id]}],
    )

    efa_security_group_id = response["SecurityGroups"][0]["GroupId"]
    return efa_security_group_id


def get_default_subnet_for_az(ec2_client, availability_zone):
    """
    Get subnet-id associated with a particular AZ using ec2_client for that region
    :param ec2_client: boto3 EC2 Client object
    :param availability_zone: str Availability Zone name
    :return: str subnet-id
    """
    response = ec2_client.describe_subnets(
        Filters=[
            {"Name": "availability-zone", "Values": [availability_zone]},
            {"Name": "default-for-az", "Values": ["true"]},
        ]
    )
    az_subnet_id = response["Subnets"][0]["SubnetId"]
    return az_subnet_id


def get_subnet_id_by_vpc(ec2_client, vpc_id):
    """
    Get all subnet IDs belonging to a particular VPC
    :param ec2_client: boto3 EC2 Client object
    :param vpc_id: str VPC ID whose subnets are to be listed
    :return: list of str subnet-ids
    """
    response = ec2_client.describe_subnets(
        Filters=[
            {
                "Name": "vpc-id",
                "Values": [
                    vpc_id,
                ],
            },
        ],
    )

    return [subnet["SubnetId"] for subnet in response["Subnets"]]


def get_vpc_id_by_name(ec2_client, vpc_name):
    """
    Get VPC ID by VPC name tag
    :param ec2_client: boto3 EC2 Client object
    :param vpc_name: Name tag value of the VPC
    :return: str VPC ID of the VPC name
    """
    response = ec2_client.describe_vpcs(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).get(
        "Vpcs", []
    )

    if not response:
        raise Exception(f"No VPC found with Name tag: {vpc_name}")
    elif len(response) > 1:
        raise Exception(f"Multiple VPCs found with Name tag: {vpc_name}")

    vpc_id = response[0]["VpcId"]

    return vpc_id


def get_default_security_group_id_by_vpc_id(ec2_client, vpc_name):
    """
    Get the default SG ID for a non-default VPC, looked up by its Name tag
    :param ec2_client: boto3 EC2 Client object
    :param vpc_name: Name tag value of the VPC
    :return: str SG ID of the default SG
    """
    try:
        vpc_id = get_vpc_id_by_name(ec2_client, vpc_name)

        response = ec2_client.describe_security_groups(
            Filters=[
                {"Name": "vpc-id", "Values": [vpc_id]},
                {"Name": "group-name", "Values": ["default"]},
            ],
        )

        security_group_id = response["SecurityGroups"][0]["GroupId"]
        return security_group_id
    except Exception as e:
        LOGGER.error(f"Error in get_default_security_group_id_by_vpc_id: {str(e)}")
        raise


def get_ipv6_efa_enabled_security_group_id(ec2_client, vpc_name):
    """
    Get EFA-enabled SG ID for IPv6 VPC by identifying security groups that allow
    all traffic within themselves
    :param ec2_client: boto3 EC2 Client object
    :param vpc_name: Name tag value of the VPC
    :return: str SG ID of the EFA-enabled SG
    """
    try:
        vpc_id = get_vpc_id_by_name(ec2_client, vpc_name)

        response = ec2_client.describe_security_groups(
            Filters=[
                {"Name": "vpc-id", "Values": [vpc_id]},
            ]
        )

        def _allows_all_traffic_within_self(rules, group_id):
            return any(
                rule["IpProtocol"] == "-1"
                and any(pair["GroupId"] == group_id for pair in rule.get("UserIdGroupPairs", []))
                for rule in rules
            )

        for sg in response["SecurityGroups"]:
            if _allows_all_traffic_within_self(
                sg["IpPermissions"], sg["GroupId"]
            ) and _allows_all_traffic_within_self(sg["IpPermissionsEgress"], sg["GroupId"]):
                return sg["GroupId"]

        raise ValueError(
            f"No EFA-enabled security group found in VPC {vpc_name}. "
            "Expected an SG that allows all traffic to and from itself."
        )
    except Exception as e:
        LOGGER.error(f"Error when getting IPv6 EFA-enabled sg id: {str(e)}")
        raise


def get_ipv6_enabled_subnet_for_az(ec2_client, vpc_name, availability_zone):
    """
    Get the ID of an IPv6-enabled subnet in a particular availability zone
    :param ec2_client: boto3 EC2 Client object
    :param vpc_name: Name tag value of the VPC
    :param availability_zone: str AZ name
    :return: str Subnet ID of an IPv6-enabled subnet
    """
    try:
        vpc_id = get_vpc_id_by_name(ec2_client, vpc_name)

        route_tables = ec2_client.describe_route_tables(
            Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]
        )["RouteTables"]

        response = ec2_client.describe_subnets(
            Filters=[
                {"Name": "vpc-id", "Values": [vpc_id]},
                {"Name": "availability-zone", "Values": [availability_zone]},
            ]
        )

        ipv6_subnets = [
            subnet
            for subnet in response["Subnets"]
            if subnet.get("Ipv6CidrBlockAssociationSet")
            and is_public_subnet(subnet["SubnetId"], route_tables)
        ]

        if not ipv6_subnets:
            raise Exception(
                f"No IPv6-enabled subnet found in AZ {availability_zone} for VPC {vpc_id}"
            )

        return ipv6_subnets[0]["SubnetId"]
    except Exception as e:
        LOGGER.error(
            f"Error when getting IPv6-enabled subnet for AZ {availability_zone}: {str(e)}"
        )
        raise


def is_public_subnet(subnet_id, route_tables):
    """
    Check if a subnet is public by verifying if it has a route table with an Internet Gateway
    that routes all IPv4 or IPv6 traffic
    :param subnet_id: str the subnet ID to check
    :param route_tables: list route tables from the VPC
    :return: True if subnet is public, False otherwise
    """
    for route_table in route_tables:
        has_igw = False
        for route in route_table.get("Routes", []):
            if route.get("GatewayId", "").startswith("igw-"):
                if (
                    route.get("DestinationCidrBlock") == "0.0.0.0/0"
                    or route.get("DestinationIpv6CidrBlock") == "::/0"
                ):
                    has_igw = True
                    break
        if not has_igw:
            continue

        # check if subnet is associated with route table
        for association in route_table.get("Associations", []):
            if association.get("SubnetId") == subnet_id:
                return True

    return False


def generate_standard_dual_stack_network_interface(ec2_client, availability_zone):
    """
    Generate network interface configuration for dual-stack (IPv4/IPv6) instances.
    :param ec2_client: boto3 EC2 Client
    :param availability_zone: str AZ in which the instance must be created
    :return: list containing a single network interface configuration for dual-stack
    """
    try:
        if not IPV6_VPC_NAME:
            raise ValueError("IPv6 VPC name is not set")

        ipv6_default_sg = get_default_security_group_id_by_vpc_id(ec2_client, IPV6_VPC_NAME)
        ipv6_subnet_id = get_ipv6_enabled_subnet_for_az(
            ec2_client, IPV6_VPC_NAME, availability_zone
        )

        network_interfaces = [
            {
                "DeviceIndex": 0,
                "DeleteOnTermination": True,
                "Groups": [ipv6_default_sg],
                "SubnetId": ipv6_subnet_id,
                "Ipv6AddressCount": 1,
            }
        ]

        return network_interfaces

    except Exception as e:
        LOGGER.error(
            f"Failed to generate dual-stack network interface in AZ {availability_zone}: {str(e)}"
        )
        raise


def generate_network_interfaces(ec2_client, ec2_instance_type, availability_zone):
    """
    Generate list of EFA-network-interfaces based on the number of network-interfaces available
    on a given instance type.
    :param ec2_client: boto3 EC2 Client
    :param ec2_instance_type: str EC2 Instance Type with network interface to be configured
    :param availability_zone: str AZ in which the instance must be created
    :return: list of dicts mapping each network-interface available
    """
    num_efa_interfaces = get_num_efa_interfaces_for_instance_type(ec2_instance_type)
    if not num_efa_interfaces:
        raise AttributeError(f"Unable to get number of EFA Interfaces for {ec2_instance_type}")

    if ENABLE_IPV6_TESTING:
        vpc_name = IPV6_VPC_NAME
        efa_sg = get_ipv6_efa_enabled_security_group_id(ec2_client, vpc_name)
        sg_ids = [efa_sg]
        subnet_id = get_ipv6_enabled_subnet_for_az(ec2_client, vpc_name, availability_zone)
    else:
        default_sg = get_default_security_group_id(ec2_client)
        efa_sg = get_efa_enabled_security_group_id(ec2_client)
        sg_ids = [default_sg, efa_sg]
        subnet_id = get_default_subnet_for_az(ec2_client, availability_zone)

    network_interfaces = []
    for i in range(num_efa_interfaces):
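        # Only one interface can occupy DeviceIndex 0; the remaining EFA
        # interfaces on multi-card instances (e.g. p4d) attach at DeviceIndex 1
        # and are told apart by their NetworkCardIndex.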
        interface = {
            "DeviceIndex": 0 if i == 0 else 1,
            "NetworkCardIndex": i,
            "DeleteOnTermination": True,
            "InterfaceType": "efa",
            "Groups": sg_ids,
            "SubnetId": subnet_id,
        }

        network_interfaces.append(interface)

    return network_interfaces

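# Sketch of typical usage when launching an EFA-enabled instance (all values
# other than the generated interfaces are placeholders):
#
#   network_interfaces = generate_network_interfaces(ec2_client, "p4d.24xlarge", "us-west-2a")
#   ec2_client.run_instances(
#       ImageId="ami-xxxxxxxx",
#       InstanceType="p4d.24xlarge",
#       MinCount=1,
#       MaxCount=1,
#       NetworkInterfaces=network_interfaces,
#   )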

def get_network_interface_id(instance_id, region=DEFAULT_REGION):
    """
    Gets the network interface at index 0 from the instance_id. Meant to be used
    with p4d instance with 4 efa devices
    """
    instance = get_instance_from_id(instance_id, region)
    network_interfaces_info = instance["NetworkInterfaces"]
    for device in network_interfaces_info:
        if device["Attachment"]["DeviceIndex"] == 0:
            return device["NetworkInterfaceId"]

    raise Exception("Could not find network device 0, retry operation")


def get_ipv6_address_for_eth0(instance_id, region=DEFAULT_REGION):
    """
    Gets the IPv6 address specifically from eth0 (Device Index 0) of an EC2 instance
    """
    instance = get_instance_from_id(instance_id, region)
    network_interfaces_info = instance["NetworkInterfaces"]
    for device in network_interfaces_info:
        if device["Attachment"]["DeviceIndex"] == 0:
            if device["Ipv6Addresses"]:
                return device["Ipv6Addresses"][0]["Ipv6Address"]
            LOGGER.info(f"No IPv6 address found on eth0 for instance {instance_id}")
            return None

    LOGGER.error(f"Could not find eth0 for instance {instance_id}")
    return None


def attach_elastic_ip(network_interface_id, region="us-east-1", is_ipv6=False):
    """
    Creates and attaches an elastic ip to a network interface which is already
    attached to an efa enabled device. This is needed specifically for 4 efa devices
    attached to a p4d instance. Having multiple network devices prevents automatic
    public ip address assignment, so we must do it manually.
    """
    ec2_client = boto3.client("ec2", region_name=region)
    arguments_dict = {
        "Domain": "vpc",
        "TagSpecifications": [
            {
                "ResourceType": "elastic-ip",
                "Tags": [{"Key": "Name", "Value": f"elastic_ip_{network_interface_id}"}],
            }
        ],
    }
    elastic_ip = ec2_client.allocate_address(**arguments_dict)
    elastic_ip_allocation_id = elastic_ip["AllocationId"]
    ec2_client.associate_address(
        AllocationId=elastic_ip_allocation_id, NetworkInterfaceId=network_interface_id
    )
    if is_ipv6:
        ec2_client.assign_ipv6_addresses(
            NetworkInterfaceId=network_interface_id, Ipv6AddressCount=1
        )
    return elastic_ip_allocation_id

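# Sketch of the elastic-IP lifecycle around an EFA test (instance_id is assumed
# to come from the caller's instance setup):
#
#   eni_id = get_network_interface_id(instance_id)
#   allocation_id = attach_elastic_ip(eni_id, region=DEFAULT_REGION, is_ipv6=ENABLE_IPV6_TESTING)
#   ...  # run tests against the now publicly reachable instance
#   delete_elastic_ips([allocation_id], boto3.client("ec2", region_name=DEFAULT_REGION))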

def delete_elastic_ips(elastic_ip_allocation_ids, ec2_client):
    """
    Deletes elastic ips created for efa p4d testing.
    For default VPC (IPv4): can release directly
    For non-default VPC (IPv6): need to disassociate before release
    """
    for allocation_id in elastic_ip_allocation_ids:
        try:
            if ENABLE_IPV6_TESTING:
                address = ec2_client.describe_addresses(AllocationIds=[allocation_id])["Addresses"][
                    0
                ]
                if "AssociationId" in address:
                    ec2_client.disassociate_address(AssociationId=address["AssociationId"])
                    time.sleep(10)
            ec2_client.release_address(AllocationId=allocation_id)
        except Exception as e:
            LOGGER.error(f"Failed to delete elastic ip {allocation_id}: {str(e)}")


def create_name_tags_for_instance(instance_id, name_tag, region):
    """
    Create name tags for an instance
    :param instance_id: str Instance ID on which to apply the given Name tag
    :param name_tag: str Name tag to be applied
    :param region: str Region in which instance is running
    """
    ec2_client = boto3.client("ec2", region_name=region)
    response = ec2_client.create_tags(
        Resources=[instance_id],
        Tags=[{"Key": "Name", "Value": name_tag}],
    )
    # create_tags always returns a response dict, so `if not response` can never
    # fire; check the HTTP status code instead.
    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        raise Exception(f"Unable to create name tag {name_tag} for instance {instance_id}")


def get_efa_devices_on_instance(connection):
    """
    Get list of EFA devices available for use in an instance
    :param connection: Fabric Connection object
    :return: list of str device paths
    """
    response = connection.run("ls /dev/infiniband/uverbs*")
    devices = response.stdout.split()
    return devices
